import  torch


class Angular_Spectrum(torch.nn.Module):
    """Free-space propagation of a 2-D field via a Fourier-domain transfer function.

    The field is Fourier transformed, multiplied by
    ``H = exp(-1j * pi * wl * distance * (fx**2 + fy**2))`` and transformed
    back. ``fx`` and ``fy`` are the spatial frequencies of the FFT bins.
    This quadratic-phase kernel is the paraxial (Fresnel) form of the
    transfer function, without any constant on-axis phase term.

    Args:
        wl: Wavelength.
        N_pixels: Number of samples along each side of the (square) field.
        pixel_size: Sampling pitch of the field.
        init_distance: Propagation distance, stored as ``self.distance``.

    ``wl``, ``pixel_size`` and the distance must share one length unit so
    that the phase ``wl * distance * f**2`` is dimensionless.

    Attributes:
        pi: Value of pi as a Python float.
        distance: Propagation distance used by ``forward``.
        kz: Complex ``(N_pixels, N_pixels)`` tensor holding the phase of the
            transfer function per unit distance, laid out with the zero
            frequency at index ``N_pixels // 2`` on both axes.
    """

    def __init__(self, wl, N_pixels, pixel_size, init_distance):
        super().__init__()
        self.pi = float(torch.pi)
        self.distance = init_distance

        # Centred FFT bin frequencies (cycles per unit length). Subtracting
        # N // 2 places the zero frequency where fftshift puts it for both
        # even and odd N_pixels.
        bin_index = torch.arange(N_pixels, dtype=torch.float32) - (N_pixels // 2)
        freq = bin_index / (N_pixels * pixel_size)

        # Broadcasting builds the 2-D grid: fx varies along the last axis,
        # fy along the second-to-last one.
        fx = freq.unsqueeze(0)
        fy = freq.unsqueeze(1)
        kz = -1j * self.pi * wl * (fx ** 2 + fy ** 2)
        self.kz = kz

    def forward(self, x):
        """Propagate ``x`` by ``self.distance``.

        Args:
            x: Tensor whose last two dimensions are the
                ``(N_pixels, N_pixels)`` field; any leading dimensions are
                treated as independent batch dimensions.

        Returns:
            Complex tensor with the same shape as ``x`` holding the
            propagated field.
        """
        # kz is a plain attribute rather than a registered buffer, so
        # Module.to() does not move it; follow the input's device here.
        H = torch.exp(self.kz.to(x.device) * self.distance)

        # Move the kernel from the centred layout to the native FFT layout
        # instead of shifting the data. Only the two spatial axes are
        # shifted, so batch/channel axes are never reordered, and ifftshift
        # is the exact inverse of the centring for odd sizes as well.
        H = torch.fft.ifftshift(H, dim=(-2, -1))

        return torch.fft.ifft2(torch.fft.fft2(x) * H)